Temporal encoding for HGTConv #10469
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR adds support for Relative Temporal Encoding (RTE) to the
HGTConvlayer, as described in the original "Heterogeneous Graph Transformer" paper.Description of Changes
use_RTEflag: A new boolean argumentuse_RTEis added to theHGTConvconstructor to enable or disable temporal encoding. When enabled, it initializes aPositionalEncodingmodule.New
forwardargument: The forward method now accepts an optionaledge_time_diff_dict. This dictionary should contain a 1D tensor of time differences (∆T) for each edge type, which serves as the input to the encoding function.Input Validation: A new
_validate_inputshelper function has been added to ensure that ifuse_RTEis enabled, theedge_time_diff_dictis provided and contains a time difference tensor for every edge type. It also warns the user if they provide time data whenuse_RTEis disabled.RTE Application: In the
messagefunction, the calculated temporal encoding (temporal_features) is added to the key (k_j) and value (v_j) vectors of the source nodes. This injects the temporal information directly into the attention mechanism.Implementation Note
This implementation adds temporal encoding to the key (
k_j) and value (v_j) vectors after their projection (a deviation from the paper) to preserve the efficient, parallelized node-level computation, which would otherwise become a much slower, edge-specific operation.Tests Added
Tests have been added to validate this feature:
References
Hu, Z., Dong, Y., Wang, K., & Sun, Y. (2020). Heterogeneous Graph Transformer.
arXiv link: https://arxiv.org/abs/2003.01332